package cn.edu.bjtu.model.core;

import java.io.IOException;
import java.io.InputStream;
import java.util.Date;
import java.util.concurrent.Future;

import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.textplantform.common.core.Deep4jModelType;
import org.textplantform.common.provider.NetworkModelService;
/**
 * 该类通过ThreadLocal的方式,每个线程分别维护自己的模型,分别调用自己的模型方法获取最后的结果,
 * 将DEE4J中原来线程非安全的类,通过线程隔离的变成安全的
 * 该模型可以在并发的环境下调用
 * 
 * @author Alex
 *
 */
class ComputationGraphSlave extends Slave{
		ThreadLocal<Long> threadModelLastModified = new ThreadLocal<Long>();
		ThreadLocal<ComputationGraph> threadModelBackup = new ThreadLocal<ComputationGraph>();
		ThreadLocal<ComputationGraph> threadModel = new ThreadLocal<ComputationGraph>(){
			//初始化错误的话直接返回空值,没出错的话返回实际值
			protected ComputationGraph initialValue() {
				return null;
//				try{
//					ComputationGraph cg = loadModel();
//					if(cg != null){
//						//成功
//						threadModelLastModified.set(new Date().getTime());
//					}
//					return cg;
//				}catch (Exception e) {
//					logException(e);
//					//分类线程 初次加载的时候出错了 加载不到模型
//					return null;
//				}finally{
//				}
			};
			public ComputationGraph get() {
				if(threadModelLastModified.get() == null || threadModelLastModified.get()<ComputationGraphSlave.this.modelLastModifiedTime){
					//当前线程的模型最后更新时间小于网络轮询得到的时间,说明过期,需要重要加载
					ComputationGraph cg = super.get();
					if(cg == null){
						try{
							//第一次加载
							cg = loadModel();
							super.set(cg);
							threadModelLastModified.set(new Date().getTime());
							return cg;
						}catch(Exception e){
							e.printStackTrace();
							return null;
						}
					}
					threadModelBackup.set(super.get());
					try{
						//获取最新模型
						cg = loadModel();
						super.set(cg);
						//更新时间
						threadModelLastModified.set(ComputationGraphSlave.this.modelLastModifiedTime);
						//备份清空
						threadModelBackup.set(null);
					}catch (Exception e) {
						//失败的话,恢复当前线程模型
						super.set(threadModelBackup.get());
					}
				}
				return super.get();
			};
		};
		
		private ComputationGraph loadModel() throws IOException{
			 logger.info(" Loading model  ......    ");
			 return ModelSerializer.restoreComputationGraph(getStream());
		}
		
		public ComputationGraphSlave(NetworkModelService nms) {
			super(nms,true);
		}
		//Sync
		public INDArray[] output(INDArray... input) {
			ComputationGraph threadLocalModel = threadModel.get();
			return threadLocalModel == null?DumpSlave.ALLONE.output(input):threadLocalModel.output(input);
		}
		
		//Sync
		public Evaluation evaluate(DataSetIterator iterator) {
			ComputationGraph threadLocalModel = threadModel.get();
			return threadLocalModel == null?DumpSlave.ALLONE.evaluate(iterator):threadLocalModel.evaluate(iterator);
		}
		@Override
		public Deep4jModelType getActualType() {
			return Deep4jModelType.ComputationGraph;
		}		
		@Override
		public Object get() {
			return threadModel.get();
		}
		@Override
		public void load() throws Exception {
			//暂时不加载,等有调用evaluate方法的线程时,让它们去加载  
			//do noting
			logger.info(" === Lazy Loading === ");
		}
		//Async
		@Override
		public Future<Evaluation> evaluateAsync(DataSetIterator iterator) throws Exception {
			throw new UnsupportedOperationException(" ComputationGraphSlave 不支持异步调用    ");
		}
		//Async
		@Override
		public Future<INDArray[]> outputAsync(INDArray... input) throws Exception {
			throw new UnsupportedOperationException(" ComputationGraphSlave 不支持异步调用    ");
		}
		
	}